from ..smp import *
from ..dataset import img_root_map, DATASET_TYPE
from abc import abstractmethod


class BaseModel:
    """Common base class for models: input normalisation plus generate/chat entry points.

    Subclasses are expected to override `generate_inner` (and `chat_inner` if they
    want to support `chat`). Note that `@abstractmethod` is used below only as a
    marker: this class does not use `ABCMeta`, so instantiation is not blocked
    when the marked methods are missing.
    """

    # Whether the model accepts interleaved image/text input. When False,
    # `message_to_promptimg` can be used to flatten a message.
    INTERLEAVE = False
    # Item types that `generate` accepts in a preprocessed message.
    allowed_types = ["text", "image", "video"]

    def __init__(self):
        # Callable installed later through `set_dump_image`.
        self.dump_image_func = None

    def use_custom_prompt(self, dataset):
        """Whether to use custom prompt for the given dataset.

        Args:
            dataset (str): The name of the dataset.

        Returns:
            bool: Whether to use custom prompt. If True, will call `build_prompt` of the VLM to build the prompt.
                Default to False.
        """
        return False

    @abstractmethod
    def build_prompt(self, line, dataset):
        """Build custom prompts for a specific dataset. Called only if `use_custom_prompt` returns True.

        Args:
            line (line of pd.DataFrame): The raw input line.
            dataset (str): The name of the dataset.

        Returns:
            str: The built message.
        """
        raise NotImplementedError

    def set_dump_image(self, dump_image_func):
        """Install the callable that `dump_image` delegates to."""
        self.dump_image_func = dump_image_func

    def dump_image(self, line, dataset):
        """Call the installed dump function on `line` and return its result.

        `dataset` is accepted but not forwarded to the dump function.
        """
        return self.dump_image_func(line)

    @abstractmethod
    def generate_inner(self, message, dataset=None):
        """Produce the model output for a preprocessed message; must be overridden."""
        raise NotImplementedError

    def check_content(self, msgs):
        """Check the content type of the input. Four types are allowed: str, dict, liststr, listdict.

        Returns:
            str: One of "str", "dict", "liststr", "listdict", or "unknown" for
                anything else (including lists mixing different element kinds).
        """
        if isinstance(msgs, str):
            return "str"
        if isinstance(msgs, dict):
            return "dict"
        if isinstance(msgs, list):
            kinds = [self.check_content(m) for m in msgs]
            if all(k == "str" for k in kinds):
                return "liststr"
            if all(k == "dict" for k in kinds):
                return "listdict"
        return "unknown"

    def preproc_content(self, inputs):
        """Convert the raw input messages to a list of dicts.

        Args:
            inputs: raw input messages.

        Returns:
            list(dict): The preprocessed input messages. Will return None if failed to preprocess the input.

        Raises:
            AssertionError: If a dict item lacks "type"/"value", or if its
                declared type disagrees with what `parse_file` reports.
        """
        # Classify once instead of re-walking the input for every branch.
        kind = self.check_content(inputs)

        if kind == "str":
            return [dict(type="text", value=inputs)]

        if kind == "dict":
            assert "type" in inputs and "value" in inputs
            return [inputs]

        if kind == "liststr":
            res = []
            for s in inputs:
                mime, pth = parse_file(s)
                if mime is None or mime == "unknown":
                    res.append(dict(type="text", value=s))
                else:
                    res.append(dict(type=mime.split("/")[0], value=pth))
            return res

        if kind == "listdict":
            # The items are validated and updated in place; the same list is returned.
            for item in inputs:
                assert "type" in item and "value" in item
                mime, s = parse_file(item["value"])
                # Treat an "unknown" mime like the liststr branch does: as plain text.
                if mime is None or mime == "unknown":
                    assert item["type"] == "text"
                else:
                    assert mime.split("/")[0] == item["type"]
                    item["value"] = s
            return inputs

        return None

    def generate(self, message, dataset=None):
        """Generate the output message.

        Args:
            message (list[dict]): The input message.
            dataset (str, optional): The name of the dataset. Defaults to None.

        Returns:
            str: The generated message.
        """
        assert self.check_content(message) in [
            "str",
            "dict",
            "liststr",
            "listdict",
        ], f"Invalid input type: {message}"
        message = self.preproc_content(message)
        assert message is not None and self.check_content(message) == "listdict"
        for item in message:
            assert (
                item["type"] in self.allowed_types
            ), f'Invalid input type: {item["type"]}'
        return self.generate_inner(message, dataset)

    def chat(self, messages, dataset=None):
        """The main function for multi-turn chatting. Will call `chat_inner` with the preprocessed input messages.

        Each message must be a dict with "role" and "content"; its "content" is
        replaced in place by the preprocessed form. If `chat_inner` raises, the
        oldest turn is dropped (together with any following non-user turns) and
        the call is retried on the shorter history.

        Returns:
            The result of `chat_inner`, or a fixed failure string if every
            truncated history raised.
        """
        assert hasattr(
            self, "chat_inner"
        ), "The API model should has the `chat_inner` method. "
        for msg in messages:
            assert isinstance(msg, dict) and "role" in msg and "content" in msg, msg
            assert self.check_content(msg["content"]) in [
                "str",
                "dict",
                "liststr",
                "listdict",
            ], msg
            msg["content"] = self.preproc_content(msg["content"])

        while len(messages):
            try:
                return self.chat_inner(messages, dataset=dataset)
            except Exception as e:
                # Best-effort retry: log and fall back to a shorter history.
                logging.info(f"{type(e)}: {e}")
                messages = messages[1:]
                # Make the retried history start on a user turn.
                while len(messages) and messages[0]["role"] != "user":
                    messages = messages[1:]
                continue
        return "Chat Mode: Failed with all possible conversation turns."

    def message_to_promptimg(self, message, dataset=None):
        """Flatten a message into one text prompt and at most one image.

        Text items are joined with newlines. With no image the image is None;
        for the "BLINK" dataset all images are passed to `concat_images_vlmeval`;
        otherwise only the first image is kept.

        Returns:
            tuple: (prompt, image).
        """
        assert not self.INTERLEAVE
        model_name = self.__class__.__name__
        warnings.warn(
            f"Model {model_name} does not support interleaved input. "
            "Will use the first image and aggregated texts as prompt. "
        )
        # The prompt does not depend on the number of images, so build it once.
        prompt = "\n".join([x["value"] for x in message if x["type"] == "text"])
        images = [x["value"] for x in message if x["type"] == "image"]
        if not images:
            image = None
        elif "BLINK" == dataset:
            image = concat_images_vlmeval(images, target_size=512)
        else:
            image = images[0]
        return prompt, image

    def message_to_promptvideo(self, message):
        """Flatten a message into one text prompt and at most one video.

        Returns:
            tuple: (prompt, video), where video is the first video item's value
                or None if the message has no video.

        Raises:
            NotImplementedError: If the model does not set a truthy `VIDEO_LLM`.
        """
        # `VIDEO_LLM` is not defined on this base class; a missing flag means "no video support".
        if not getattr(self, "VIDEO_LLM", False):
            logging.critical("Model does not support video input.")
            raise NotImplementedError

        prompt = "\n".join([x["value"] for x in message if x["type"] == "text"])
        videos = [x["value"] for x in message if x["type"] == "video"]
        video = videos[0] if videos else None
        return prompt, video

    def message_to_promptvideo_withrole(self, message, dataset=None):
        """Split a message into per-role text plus a single video.

        Text items are concatenated per role ("system", "assistant", anything
        else counts as user). When there is no assistant text, an MCQ dataset
        gets the assistant prefix "Best Option: (" and other datasets get no
        "assistant" key at all.

        Returns:
            tuple: (question, video), where question is a dict of role -> text
                and video is the first video item's value or None if the
                message has no video.

        Raises:
            NotImplementedError: If the model does not set a truthy `VIDEO_LLM`.
        """
        # `VIDEO_LLM` is not defined on this base class; a missing flag means "no video support".
        if not getattr(self, "VIDEO_LLM", False):
            logging.critical("Model does not support video input.")
            raise NotImplementedError

        system, user, assistant, video_list = "", "", "", []
        for msg in message:
            if msg["type"] == "text":
                role = msg.get("role")
                if role == "system":
                    system += msg["value"]
                elif role == "assistant":
                    assistant += msg["value"]
                else:
                    user += msg["value"]
            elif msg["type"] == "video":
                video_list.append(msg["value"])

        question = {"system": system, "user": user, "assistant": assistant}
        if assistant == "":
            if listinstr(["MCQ"], DATASET_TYPE(dataset)):
                question["assistant"] = "Best Option: ("
            else:
                del question["assistant"]

        if len(video_list) > 1:
            print(
                "VLMEvalKit only support single video as input, take first video as input"
            )
        # Mirror `message_to_promptvideo`: no video yields None instead of an IndexError.
        video = video_list[0] if video_list else None
        return question, video
